"""
Phase 10: Financial Structure as Crisis Moderator
==================================================
Tests whether bank-based vs market-based financial structure moderates
demographic crisis vulnerability.

Data:
  - FS.AST.PRVT.GD.ZS  (Domestic credit to private sector, % GDP)
  - CM.MKT.LCAP.GD.ZS  (Stock market capitalization, % GDP)
  - fin_structure = log(bank_credit / stock_mkt_cap)  [Beck-Levine]

Tests:
  1. Financial structure × demographics → crisis onset
  2. Financial structure × demographics → CA reversals
  3. Financial structure × demographics → crisis severity
  4. Extended risk matrix (demo × fin_structure)
"""

import sys, json, time
from pathlib import Path
import pandas as pd
import numpy as np
import requests

sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "multilateral" / "src"))
from model import PanelGLS

# Directory holding the input panel (crises_panel.csv) and the cached WDI downloads.
DATA_DIR = Path(__file__).resolve().parents[1] / "data" / "processed"
# Directory receiving the markdown tables and phase10_results.csv.
# NOTE: it is created as a side effect of importing this module.
OUTPUT_DIR = Path(__file__).resolve().parents[1] / "output" / "tables"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Candidate control variables; tests 1 and 2 keep only those present in the panel.
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']


# ── helpers ──────────────────────────────────────────────────────────────

def fmt(b, se, p):
    """Render a coefficient and its standard error as table cells.

    Returns a ``(coef_text, se_text)`` pair: the coefficient with four
    decimals followed by significance stars (``***`` for p < 0.01, ``**`` for
    p < 0.05, ``*`` for p < 0.1, none otherwise) and the standard error with
    four decimals wrapped in parentheses.
    """
    if p < 0.01:
        marker = '***'
    elif p < 0.05:
        marker = '**'
    elif p < 0.1:
        marker = '*'
    else:
        marker = ''
    coef_text = f"{b:.4f}{marker}"
    se_text = f"({se:.4f})"
    return coef_text, se_text

def run_gls(df, dv, ivs):
    """Fit PanelGLS on the complete cases of ``dv`` and ``ivs``.

    Args:
        df: Panel DataFrame containing ``dv``, every column named in ``ivs``,
            and the identifier columns ``iso3`` and ``year``.
        dv: Name of the dependent-variable column.
        ivs: List of regressor column names, in the order they are passed to
            the estimator.

    Returns:
        ``None`` when fewer than 50 complete rows remain; otherwise a dict
        with per-regressor ``coefs``, ``ses`` and ``pvals`` (keyed by the
        names in ``ivs``) plus ``N``, ``N_c``, ``R2`` and ``rho`` read from
        the fitted model.
    """
    # Listwise deletion on the model variables only. dropna() already returns
    # a new frame and it is only read below, so no extra copy is needed.
    sub = df.dropna(subset=[dv] + ivs)
    if len(sub) < 50:
        return None
    m = PanelGLS()
    m.fit(sub[dv].values, sub[ivs].values, sub['iso3'].values, sub['year'].values)
    return {
        'coefs': dict(zip(ivs, m.beta)),
        'ses':   dict(zip(ivs, m.se)),
        'pvals': dict(zip(ivs, m.pvalues)),
        'N': m.n_obs, 'N_c': m.n_countries,
        'R2': m.r_squared, 'rho': m.rho,
    }


def write_table(path, title, panels):
    """Write one or more markdown tables to ``path`` as UTF-8.

    Args:
        path: ``pathlib.Path`` of the output file; it is overwritten.
        title: Text of the level-2 heading placed at the top of the file.
        panels: Iterable of ``(panel_title, col_headers, rows)`` tuples. A
            falsy ``panel_title`` suppresses the bold panel caption. Each row
            is an iterable of cells, each converted with ``str``.
    """
    lines = [f"## {title}", ""]
    for ptitle, headers, rows in panels:
        if ptitle:
            lines.append(f"**{ptitle}**")
            lines.append("")
        lines.append("| " + " | ".join(headers) + " |")
        lines.append("| " + " | ".join(["---"] * len(headers)) + " |")
        lines.extend("| " + " | ".join(str(x) for x in row) + " |" for row in rows)
        lines.append("")
    # Titles and cells contain non-ASCII characters (×, →, ρ, ²), so the
    # encoding is pinned rather than left to the platform default, which can
    # fail with UnicodeEncodeError on non-UTF-8 locales.
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Wrote {path.name}")


def build_regression_rows(results_dict, show_vars, model_names):
    """Assemble the body rows of a regression table.

    For every variable in ``show_vars`` two rows are produced: one with the
    starred coefficient per model and one with the parenthesised standard
    error. Cells stay blank when a model is missing from ``results_dict``, is
    ``None``, or does not include the variable. Three footer rows (N, R², ρ)
    close the table.
    """
    fits = [results_dict.get(name) for name in model_names]

    table = []
    for var in show_vars:
        estimates, errors = [var], [""]
        for fit in fits:
            if fit is not None and var in fit['coefs']:
                est, err = fmt(fit['coefs'][var], fit['ses'][var], fit['pvals'][var])
            else:
                est, err = "", ""
            estimates.append(est)
            errors.append(err)
        table += [estimates, errors]

    def footer(label, render):
        # One summary cell per model; blank when the model was not estimated.
        return [label] + ["" if fit is None else render(fit) for fit in fits]

    table.append(footer("N", lambda fit: f"{fit['N']:,}"))
    table.append(footer("R²", lambda fit: f"{fit['R2']:.3f}"))
    table.append(footer("ρ", lambda fit: f"{fit['rho']:.3f}"))
    return table


# ── data download ────────────────────────────────────────────────────────

def fetch_wdi(indicator, cache_path):
    """Download a WDI indicator via the World Bank API v2, with a CSV cache.

    Args:
        indicator: World Bank indicator code, e.g. ``'FS.AST.PRVT.GD.ZS'``.
        cache_path: ``pathlib.Path`` of the CSV cache. When it exists it is
            read and returned without touching the network.

    Returns:
        DataFrame with columns ``iso3``, ``year`` and ``value``, one row per
        non-missing observation for 1960-2023.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        RuntimeError: If the API returns no usable observations. Nothing is
            cached in that case, so the next run retries the download instead
            of reading back an unusable, column-less cache file.
    """
    if cache_path.exists():
        print(f"  Using cached {cache_path.name}")
        return pd.read_csv(cache_path)

    print(f"  Downloading {indicator} from World Bank API...")
    all_rows = []
    page = 1
    per_page = 1000
    while True:
        url = (f"https://api.worldbank.org/v2/country/all/indicator/{indicator}"
               f"?format=json&per_page={per_page}&page={page}&date=1960:2023")
        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        data = resp.json()
        # A normal payload is [metadata, records]; anything shorter, or an
        # empty record list, ends the pagination.
        if len(data) < 2 or not data[1]:
            break
        for rec in data[1]:
            # Skip missing values and records without an ISO3 code: a blank
            # code cannot identify a country and would turn into NaN once the
            # cache is read back. (Presumably these are aggregate/regional
            # entries of `country/all` - confirm against the API.)
            if rec['value'] is None or not rec['countryiso3code']:
                continue
            all_rows.append({
                'iso3': rec['countryiso3code'],
                'year': int(rec['date']),
                'value': float(rec['value']),
            })
        total_pages = data[0].get('pages', 1)
        if page >= total_pages:
            break
        page += 1
        time.sleep(0.3)  # pause between page requests

    if not all_rows:
        raise RuntimeError(
            f"World Bank API returned no observations for {indicator}; "
            f"nothing was written to {cache_path}"
        )

    df = pd.DataFrame(all_rows, columns=['iso3', 'year', 'value'])
    # Make sure a long download is not lost to a missing cache directory.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(cache_path, index=False)
    print(f"  Downloaded {len(df):,} obs for {indicator}")
    return df


def prepare_financial_structure(df):
    """Add financial structure variables and interactions to the panel.

    Downloads (or reads from cache) domestic credit to the private sector and
    stock market capitalization, both in % of GDP, and left-merges them onto
    ``df`` by ``iso3`` and ``year``.

    Columns added:
        bank_credit_gdp, stock_mkt_cap_gdp: the two WDI series.
        fin_structure: ``log(bank_credit_gdp / stock_mkt_cap_gdp)``; defined
            only where both series are strictly positive.
        bank_based: 1.0 when ``fin_structure`` is above the panel median,
            0.0 otherwise, NaN where ``fin_structure`` is missing.
        Z_1_x_fin, Z_2_x_fin, Z_3_x_fin, old_dep_x_fin, youth_dep_x_fin:
            products of the demographic variables with ``fin_structure``.

    The median is taken over the rows of the merged panel, not over the raw
    WDI download, so the bank/market split is not shifted by countries, years
    or aggregate entries that never enter the estimation sample.

    Args:
        df: Panel with ``iso3``, ``year``, ``Z_1``-``Z_3``, ``old_dep`` and
            ``youth_dep`` columns.

    Returns:
        A new DataFrame with the columns above appended; ``df`` is not
        modified.
    """
    cache_dir = DATA_DIR
    bank = fetch_wdi('FS.AST.PRVT.GD.ZS', cache_dir / 'wdi_bank_credit.csv')
    mkt  = fetch_wdi('CM.MKT.LCAP.GD.ZS', cache_dir / 'wdi_stock_mkt_cap.csv')

    bank = bank.rename(columns={'value': 'bank_credit_gdp'})
    mkt  = mkt.rename(columns={'value': 'stock_mkt_cap_gdp'})

    # Merge
    fin = bank.merge(mkt[['iso3', 'year', 'stock_mkt_cap_gdp']], on=['iso3', 'year'], how='outer')

    # Protect the left merge below: rows without a country code do not
    # identify a country, and duplicate (iso3, year) keys would multiply
    # panel rows.
    has_code = fin['iso3'].notna() & (fin['iso3'] != '')
    fin = fin[has_code].drop_duplicates(subset=['iso3', 'year']).copy()

    # Beck-Levine financial structure; the log is undefined unless both
    # series are strictly positive, so everything else is left as NaN.
    positive = (fin['bank_credit_gdp'] > 0) & (fin['stock_mkt_cap_gdp'] > 0)
    ratio = (fin['bank_credit_gdp'] / fin['stock_mkt_cap_gdp']).where(positive)
    fin['fin_structure'] = np.log(ratio)

    # Merge into panel
    merge_cols = ['iso3', 'year', 'bank_credit_gdp', 'stock_mkt_cap_gdp', 'fin_structure']
    df = df.merge(fin[merge_cols], on=['iso3', 'year'], how='left')

    # Binary: bank-based = above the median of the panel sample
    median_fs = df['fin_structure'].median()
    has_fs = df['fin_structure'].notna()
    df['bank_based'] = (df['fin_structure'] > median_fs).astype(float).where(has_fs)

    # Create interaction terms
    for zv in ['Z_1', 'Z_2', 'Z_3']:
        df[f'{zv}_x_fin'] = df[zv] * df['fin_structure']
    df['old_dep_x_fin'] = df['old_dep'] * df['fin_structure']
    df['youth_dep_x_fin'] = df['youth_dep'] * df['fin_structure']

    n_fs = df['fin_structure'].notna().sum()
    n_bb = df['bank_based'].notna().sum()
    print(f"  Financial structure: {n_fs:,} obs with fin_structure, {n_bb:,} with bank_based")
    print(f"  Median fin_structure = {median_fs:.3f}")

    return df


# ── Test 1: Financial structure × demographics → crisis onset ────────────

def test_crisis_onset(df):
    """Table 1: fin_structure × demographics → banking/any crisis onset.

    For each of ``banking_crisis_onset`` and ``any_crisis_onset`` five models
    are estimated with ``run_gls``: demographics only, demographics plus
    ``fin_structure`` and its interactions, the bank-based and market-based
    subsamples, and an old/youth dependency specification with interactions.

    The written table holds, for BOTH dependent variables, a four-model panel
    followed by an age-decomposition panel.

    Args:
        df: Panel produced by ``prepare_financial_structure``.

    Returns:
        Dict mapping ``'<dv>_<model>'`` to the ``run_gls`` result (``None``
        when a model could not be estimated).
    """
    print("\n" + "="*70)
    print("TEST 1: Financial Structure × Demographics → Crisis Onset")
    print("="*70)

    demo = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in CONTROLS if c in df.columns]
    interactions = ['Z_1_x_fin', 'Z_2_x_fin', 'Z_3_x_fin']
    age_interactions = ['old_dep_x_fin', 'youth_dep_x_fin']

    # The subsamples do not depend on the dependent variable: build them once.
    bb = df[df['bank_based'] == 1].copy()
    mb = df[df['bank_based'] == 0].copy()

    results = {}
    for dv in ['banking_crisis_onset', 'any_crisis_onset']:
        # M1: Full sample, Z only
        results[f'{dv}_base'] = run_gls(df, dv, demo + controls)

        # M2: Full sample, Z + fin_structure + Z×fin interactions
        results[f'{dv}_interact'] = run_gls(df, dv, demo + ['fin_structure'] + controls + interactions)

        # M3: Bank-based subsample
        results[f'{dv}_bank'] = run_gls(bb, dv, demo + controls)

        # M4: Market-based subsample
        results[f'{dv}_market'] = run_gls(mb, dv, demo + controls)

        # M5: old_dep/youth_dep + fin_structure interactions
        results[f'{dv}_age'] = run_gls(df, dv, ['old_dep', 'youth_dep', 'fin_structure'] + controls + age_interactions)

    # Build table. `panels` accumulates across both dependent variables, so
    # the banking-crisis panels are not overwritten by the any-crisis ones.
    show_vars = demo + ['fin_structure'] + interactions
    age_vars = ['old_dep', 'youth_dep', 'fin_structure'] + age_interactions
    headers = ["", "Full Sample", "+ Interactions", "Bank-Based", "Market-Based"]
    age_headers = ["", "Age × Fin. Structure"]

    panels = []
    for dv, label in [('banking_crisis_onset', 'Banking Crisis Onset'),
                       ('any_crisis_onset', 'Any Crisis Onset')]:
        model_names = [f'{dv}_base', f'{dv}_interact', f'{dv}_bank', f'{dv}_market']
        rows = build_regression_rows(results, show_vars, model_names)
        panels.append((label, headers, rows))

        # Age decomposition panel
        age_rows = []
        res = results.get(f'{dv}_age')
        for var in age_vars:
            if res and var in res['coefs']:
                c, s = fmt(res['coefs'][var], res['ses'][var], res['pvals'][var])
                age_rows.append([var, c])
                age_rows.append(["", s])
        if res:
            age_rows.append(["N", f"{res['N']:,}"])
            age_rows.append(["R²", f"{res['R2']:.3f}"])
        panels.append((f"{label} — Age Decomposition", age_headers, age_rows))

    write_table(OUTPUT_DIR / 'fin_structure_crisis.md',
                "Financial Structure × Demographics → Crisis Onset", panels)
    return results


# ── Test 2: Financial structure × demographics → CA reversals ────────────

def test_ca_reversals(df):
    """Table 2: fin_structure × demographics → CA reversal / sudden stop.

    Each outcome present in ``df`` is estimated four ways with ``run_gls``:
    full sample, full sample with ``fin_structure`` and its interactions,
    bank-based subsample and market-based subsample. Outcomes missing from
    the panel are left out of both the results and the written table.

    Returns the dict of ``run_gls`` results keyed by ``'<outcome>_<model>'``.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST 2: Financial Structure × Demographics → CA Reversals")
    print(banner)

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    available_controls = [name for name in CONTROLS if name in df.columns]
    z_by_fin = ['Z_1_x_fin', 'Z_2_x_fin', 'Z_3_x_fin']

    base_spec = z_vars + available_controls
    interacted_spec = z_vars + ['fin_structure'] + available_controls + z_by_fin

    results = {}
    for outcome in ('ca_reversal', 'sudden_stop'):
        if outcome in df.columns:
            results[f'{outcome}_base'] = run_gls(df, outcome, base_spec)
            results[f'{outcome}_interact'] = run_gls(df, outcome, interacted_spec)

            bank_sample = df[df['bank_based'] == 1].copy()
            results[f'{outcome}_bank'] = run_gls(bank_sample, outcome, base_spec)

            market_sample = df[df['bank_based'] == 0].copy()
            results[f'{outcome}_market'] = run_gls(market_sample, outcome, base_spec)

    column_titles = ["", "Full Sample", "+ Interactions", "Bank-Based", "Market-Based"]
    displayed = z_vars + ['fin_structure'] + z_by_fin
    captions = {'ca_reversal': 'CA Reversal (≥3pp)', 'sudden_stop': 'Sudden Stop'}

    panels = []
    for outcome, caption in captions.items():
        # Only outcomes that were actually estimated get a panel.
        if f'{outcome}_base' in results:
            suffixes = ('base', 'interact', 'bank', 'market')
            model_names = [f'{outcome}_{suffix}' for suffix in suffixes]
            body = build_regression_rows(results, displayed, model_names)
            panels.append((caption, column_titles, body))

    write_table(OUTPUT_DIR / 'fin_structure_reversal.md',
                "Financial Structure × Demographics → CA Reversals", panels)
    return results


# ── Test 3: Financial structure × demographics → crisis severity ─────────

def test_severity(df):
    """Table 3: Conditional on crisis — does fin_structure × demographics → deeper output loss?

    Restricts the panel to rows with ``any_crisis == 1`` (and, separately,
    ``banking_crisis == 1``) and regresses ``cum_output_loss`` on the
    demographic variables, ``fin_structure`` and their interactions.

    When ``cum_output_loss`` is not in the panel a proxy is built: the
    absolute value of the negative part of ``output_gap`` if that column
    exists, otherwise of ``rgdp_growth``. The proxy is a per-row value; it is
    not accumulated over the crisis episode.

    Args:
        df: Panel produced by ``prepare_financial_structure``. It is not
            modified; the proxy column is added to a copy.

    Returns:
        Dict mapping ``'<episode>_<model>'`` to the ``run_gls`` result.
        Episode types with fewer than 50 rows are skipped and have no
        entries (and no panel in the written table).
    """
    print("\n" + "="*70)
    print("TEST 3: Financial Structure × Demographics → Crisis Severity")
    print("="*70)

    demo = ['Z_1', 'Z_2', 'Z_3']
    controls_sev = ['nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']
    controls_sev = [c for c in controls_sev if c in df.columns]

    # Construct the output-loss proxy if not present. assign() returns a new
    # frame, so the caller's panel does not silently gain a column.
    if 'cum_output_loss' not in df.columns:
        source = 'output_gap' if 'output_gap' in df.columns else 'rgdp_growth'
        df = df.assign(cum_output_loss=df[source].clip(upper=0).abs())

    # Filter to crisis episodes only
    crisis_df = df[df['any_crisis'] == 1].copy()
    banking_df = df[df['banking_crisis'] == 1].copy()

    results = {}
    interactions = ['Z_1_x_fin', 'Z_2_x_fin', 'Z_3_x_fin']
    age_interactions = ['old_dep_x_fin', 'youth_dep_x_fin']

    for label, sub in [('any_crisis', crisis_df), ('banking_crisis', banking_df)]:
        if len(sub) < 50:
            print(f"  Skipping {label}: only {len(sub)} obs")
            continue

        # M1: demographics only
        results[f'{label}_base'] = run_gls(sub, 'cum_output_loss', demo + controls_sev)

        # M2: + fin_structure + interactions
        results[f'{label}_interact'] = run_gls(sub, 'cum_output_loss',
                                                demo + ['fin_structure'] + controls_sev + interactions)

        # M3: bank-based subsample
        bb = sub[sub['bank_based'] == 1]
        results[f'{label}_bank'] = run_gls(bb, 'cum_output_loss', demo + controls_sev)

        # M4: market-based subsample
        mb = sub[sub['bank_based'] == 0]
        results[f'{label}_market'] = run_gls(mb, 'cum_output_loss', demo + controls_sev)

        # M5: old_dep/youth_dep decomposition (saved in the results, not tabulated)
        results[f'{label}_age'] = run_gls(sub, 'cum_output_loss',
                                           ['old_dep', 'youth_dep', 'fin_structure'] + controls_sev + age_interactions)

    panels = []
    for label, ptitle in [('any_crisis', 'Any Crisis'), ('banking_crisis', 'Banking Crisis')]:
        # A skipped episode type has no results; do not emit an all-blank panel.
        if f'{label}_base' not in results:
            continue
        model_names = [f'{label}_base', f'{label}_interact', f'{label}_bank', f'{label}_market']
        headers = ["", "Base", "+ Fin Interactions", "Bank-Based", "Market-Based"]
        rows = build_regression_rows(results, demo + ['fin_structure'] + interactions, model_names)
        panels.append((f"{ptitle} Episodes — Output Loss", headers, rows))

    write_table(OUTPUT_DIR / 'fin_structure_severity.md',
                "Financial Structure × Demographics → Crisis Severity", panels)
    return results


# ── Test 4: Extended risk matrix (demo stage × fin structure) ────────────

def test_risk_matrix(df):
    """Table 4: Risk matrix — demographic stage × financial structure.

    Splits ``fin_structure`` and ``Z_1`` into terciles (the ``Z_1`` cut
    points are computed on rows that also have ``fin_structure``) and reports,
    per cell, the mean of each available outcome as a percentage. When
    ``kaopen`` is available a demo × KAOPEN × fin_structure breakdown of the
    banking crisis onset rate is appended.

    In the rate panels ``n`` is the number of non-missing outcome values in
    the cell, i.e. the observations the rate is actually based on; cells with
    fewer than 5 such values show "—".

    Args:
        df: Panel produced by ``prepare_financial_structure``. It is not
            modified; the tercile columns live on an internal copy.

    Returns:
        ``{'panels': panels}`` with the panels written to the table, or an
        empty dict when fewer than 100 rows have ``fin_structure``.
    """
    print("\n" + "="*70)
    print("TEST 4: Risk Matrix — Demo Stage × Financial Structure")
    print("="*70)

    # Create fin_structure tercile
    fs_valid = df['fin_structure'].dropna()
    if len(fs_valid) < 100:
        print("  Insufficient fin_structure data for risk matrix")
        return {}

    # The tercile columns are scratch state for this table only.
    df = df.copy()

    demo_levels = ['Young', 'Mid', 'Old']
    fin_levels = ['Market', 'Mixed', 'Bank']
    fin_headers = ["Market-Based", "Mixed", "Bank-Based"]

    def rate_cell(cell, metric):
        """Format the outcome rate of one cell, counting only observed values."""
        observed = cell[metric].dropna()
        if len(observed) < 5:
            return "—"
        rate = observed.mean() * 100
        return f"{rate:.1f}% (n={len(observed)})"

    t1 = fs_valid.quantile(1/3)
    t2 = fs_valid.quantile(2/3)
    df['fin_tercile'] = pd.cut(df['fin_structure'],
                                bins=[-np.inf, t1, t2, np.inf],
                                labels=fin_levels)

    # Compute demo_tercile on the SAME subsample that has fin_structure
    fs_sub = df.dropna(subset=['fin_structure', 'Z_1'])
    z1_t1 = fs_sub['Z_1'].quantile(1/3)
    z1_t2 = fs_sub['Z_1'].quantile(2/3)
    df['demo_tercile'] = pd.cut(df['Z_1'],
                                 bins=[-np.inf, z1_t1, z1_t2, np.inf],
                                 labels=demo_levels)

    # Compute cell statistics
    metrics = {
        'banking_crisis_onset': 'Banking Crisis Rate',
        'ca_reversal': 'CA Reversal Rate',
        'any_crisis_onset': 'Any Crisis Rate',
    }

    panels = []
    for metric, metric_label in metrics.items():
        if metric not in df.columns:
            continue

        rows = []
        for ds in demo_levels:
            row = [ds]
            for ft in fin_levels:
                cell = df[(df['demo_tercile'] == ds) & (df['fin_tercile'] == ft)]
                row.append(rate_cell(cell, metric))
            rows.append(row)

        panels.append((metric_label, ["Demo Stage"] + fin_headers, rows))

    # Also add summary statistics per cell
    summ_rows = []
    for ds in demo_levels:
        row = [ds]
        for ft in fin_levels:
            cell = df[(df['demo_tercile'] == ds) & (df['fin_tercile'] == ft)]
            if len(cell) < 20:
                row.append("—")
            else:
                n = len(cell)
                mean_z1 = cell['Z_1'].mean()
                mean_fs = cell['fin_structure'].mean()
                row.append(f"N={n}, Z₁={mean_z1:.2f}, FS={mean_fs:.2f}")
        summ_rows.append(row)

    panels.append(("Cell Descriptives", ["Demo Stage"] + fin_headers, summ_rows))

    # 3×3×3 extended matrix: demo × KAOPEN × fin_structure (if KAOPEN available)
    if 'kaopen' in df.columns:
        kaopen_valid = df['kaopen'].dropna()
        kt1 = kaopen_valid.quantile(1/3)
        kt2 = kaopen_valid.quantile(2/3)
        # pd.cut needs strictly increasing edges; identical (or NaN) cut
        # points would raise, so the optional panel is skipped instead.
        if kt1 < kt2:
            df['kaopen_tercile'] = pd.cut(df['kaopen'],
                                           bins=[-np.inf, kt1, kt2, np.inf],
                                           labels=['Closed', 'Mid', 'Open'])

            # Loop-invariant: checked once instead of in every cell.
            has_onset = 'banking_crisis_onset' in df.columns

            ext_rows = []
            for ds in demo_levels:
                for kt in ['Closed', 'Mid', 'Open']:
                    row = [f"{ds} / {kt}"]
                    for ft in fin_levels:
                        cell = df[(df['demo_tercile'] == ds) &
                                  (df['kaopen_tercile'] == kt) &
                                  (df['fin_tercile'] == ft)]
                        if has_onset:
                            row.append(rate_cell(cell, 'banking_crisis_onset'))
                        else:
                            row.append("—")
                    ext_rows.append(row)

            panels.append(("Extended: Demo × KAOPEN → Banking Crisis Rate by Fin Structure",
                           ["Demo/KAOPEN"] + fin_headers, ext_rows))
        else:
            print("  Skipping KAOPEN panel: tercile cut points are not distinct")

    write_table(OUTPUT_DIR / 'fin_structure_risk_matrix.md',
                "Risk Matrix — Demographic Stage × Financial Structure", panels)
    return {'panels': panels}


# ── Save results CSV ─────────────────────────────────────────────────────

def save_results_csv(all_results):
    """Flatten every stored regression into one long-format CSV.

    ``all_results`` maps a test name to a dict of model results. Anything
    that is not a dict carrying a ``'coefs'`` entry is ignored. One row is
    emitted per (test, model, variable); the file is written to
    ``phase10_results.csv`` in the output directory only when at least one
    row exists.
    """
    records = []
    for test_name, models in all_results.items():
        if isinstance(models, dict):
            for model_name, fit in models.items():
                usable = isinstance(fit, dict) and 'coefs' in fit
                if usable:
                    records.extend(
                        {
                            'test': test_name,
                            'model': model_name,
                            'variable': var,
                            'coefficient': fit['coefs'][var],
                            'std_error': fit['ses'][var],
                            'p_value': fit['pvals'][var],
                            'N': fit['N'],
                            'N_countries': fit['N_c'],
                            'R2': fit['R2'],
                        }
                        for var in fit['coefs']
                    )

    if not records:
        return

    frame = pd.DataFrame(records)
    frame.to_csv(OUTPUT_DIR / 'phase10_results.csv', index=False)
    print(f"\n  Saved phase10_results.csv ({len(frame)} rows)")


# ── main ─────────────────────────────────────────────────────────────────

def main():
    """Run the Phase 10 pipeline end to end.

    Loads the crises panel, attaches the financial structure variables, runs
    the four tests in order and saves the combined regression results.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 10: Financial Structure as Crisis Moderator")
    print(rule)

    # Load panel
    panel = pd.read_csv(DATA_DIR / 'crises_panel.csv')
    n_countries = panel['iso3'].nunique()
    print(f"  Loaded crises panel: {len(panel):,} obs, {n_countries} countries")

    # Add financial structure data
    panel = prepare_financial_structure(panel)

    # Run tests, in the order the tables are numbered
    stages = (
        ('crisis_onset', test_crisis_onset),
        ('ca_reversal', test_ca_reversals),
        ('severity', test_severity),
        ('risk_matrix', test_risk_matrix),
    )
    all_results = {}
    for key, runner in stages:
        all_results[key] = runner(panel)

    # Save combined results
    save_results_csv(all_results)

    print("\n" + rule)
    print("Phase 10 complete.")
    print(rule)


# Run the full pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
